{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "417a9e80-e1be-4db4-bfa5-831570a39fe3",
   "metadata": {},
   "source": [
    "# HF Transformers 核心模块学习：Pipelines\n",
    "\n",
    "**Pipelines**（管道）是使用模型进行推理的一种简单易上手的方式。\n",
    "\n",
    "这些管道是抽象了 Transformers 库中大部分复杂代码的对象，提供了一个专门用于多种任务的简单API，包括**命名实体识别、掩码语言建模、情感分析、特征提取和问答**等。\n",
    "\n",
    "\n",
    "| Modality                    | Task                         | Description                                                | Pipeline API                                  |\n",
    "| --------------------------- | ---------------------------- | ---------------------------------------------------------- | --------------------------------------------- |\n",
    "| Audio                       | Audio classification         | 为音频文件分配一个标签                                     | pipeline(task=“audio-classification”)         |\n",
    "|                             | Automatic speech recognition | 将音频文件中的语音提取为文本                               | pipeline(task=“automatic-speech-recognition”) |\n",
    "| Computer vision             | Image classification         | 为图像分配一个标签                                         | pipeline(task=“image-classification”)         |\n",
    "|                             | Object detection             | 预测图像中目标对象的边界框和类别                           | pipeline(task=“object-detection”)             |\n",
    "|                             | Image segmentation           | 为图像中每个独立的像素分配标签（支持语义、全景和实例分割） | pipeline(task=“image-segmentation”)           |\n",
    "| Natural language processing | Text classification          | 为给定的文本序列分配一个标签                               | pipeline(task=“sentiment-analysis”)           |\n",
    "|                             | Token classification         | 为序列里的每个 token 分配一个标签（人, 组织, 地址等等）    | pipeline(task=“ner”)                          |\n",
    "|                             | Question answering           | 通过给定的上下文和问题, 在文本中提取答案                   | pipeline(task=“question-answering”)           |\n",
    "|                             | Summarization                | 为文本序列或文档生成总结                                   | pipeline(task=“summarization”)                |\n",
    "|                             | Translation                  | 将文本从一种语言翻译为另一种语言                           | pipeline(task=“translation”)                  |\n",
    "| Multimodal                  | Document question answering  | 根据给定的文档和问题回答一个关于该文档的问题。             | pipeline(task=“document-question-answering”)  |\n",
    "|                             | Visual Question Answering    | 给定一个图像和一个问题，正确地回答有关图像的问题           | pipeline(task=“vqa”)                          |\n",
    "\n",
    "\n",
    "\n",
    "Pipelines 已支持的完整任务列表：https://huggingface.co/docs/transformers/task_summary\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "418f2e3c-a185-4164-ab8b-3b8c98c3b4a5",
   "metadata": {},
   "source": [
    "## Pipeline API\n",
    "\n",
    "**Pipeline API** 是对所有其他可用管道的包装。它可以像任何其他管道一样实例化，并且降低AI推理的学习和使用成本。\n",
    "\n",
    "![](docs/images/pipeline_func.png)\n",
    "\n",
    "### 使用 Pipeline API 实现 Text Classification 任务\n",
    "\n",
    "\n",
    "**Text classification**(文本分类)与任何模态中的分类任务一样，文本分类将一个文本序列（可以是句子级别、段落或者整篇文章）标记为预定义的类别集合之一。文本分类有许多实际应用，其中包括：\n",
    "\n",
    "- 情感分析：根据某种极性（如积极或消极）对文本进行标记，以在政治、金融和市场等领域支持决策制定。\n",
    "- 内容分类：根据某个主题对文本进行标记，以帮助组织和过滤新闻和社交媒体信息流中的信息（天气、体育、金融等）。\n",
    "\n",
    "\n",
    "下面以 `Text classification` 中的情感分析任务为例，展示如何使用 Pipeline API。\n",
    "\n",
    "模型主页：https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9c38e82-e8f6-4af0-b257-b46f446aa249",
   "metadata": {
    "tags": []
   },
   "source": [
    "## transformers 自定义模型下载的路径\n",
    "\n",
    "在transformers自定义模型下载的路径方法\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "os.environ['HF_HOME'] = '/mnt/new_volume/hf'\n",
    "os.environ['HF_HUB_CACHE'] = '/mnt/new_volume/hf/hub'\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9684ec5f-9460-4876-9883-69380eacb0e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english).\n",
      "Using a pipeline without specifying a model name and revision in production is not recommended.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'label': 'NEGATIVE', 'score': 0.8957212567329407}]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# 仅指定任务时，使用默认模型（不推荐）\n",
    "pipe = pipeline(\"sentiment-analysis\")\n",
    "pipe(\"今儿上海可真冷啊\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63a12486-952a-447d-9e90-6b041349a26e",
   "metadata": {},
   "source": [
    "### 测试更多示例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ee67f529-3d32-4834-b29a-e51a65cf4d49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'NEGATIVE', 'score': 0.9238728880882263}]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe(\"我觉得这家店蒜泥白肉的味道一般\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4e104034-94ca-4370-8d35-438501c29ead",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'NEGATIVE', 'score': 0.8578686118125916}]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 默认使用的模型 distilbert-base-uncased-finetuned-sst-2-english \n",
    "# 并未针对中文做太多训练，中文的文本分类任务表现未必满意\n",
    "pipe(\"你学东西真的好快，理论课一讲就明白了\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9c3d11aa-d523-491c-8910-bdc8aba49487",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'POSITIVE', 'score': 0.9961802959442139}]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 替换为英文后，文本分类任务的表现立刻改善\n",
    "pipe(\"You learn things really quickly. You understand the theory class as soon as it is taught.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "09df7300-2f0d-4bc4-9572-0631658e8253",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'NEGATIVE', 'score': 0.9995032548904419}]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe(\"Today Shanghai is really cold.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "507dbfc3-5347-4b82-8ea1-4e6c3d81c07f",
   "metadata": {},
   "source": [
    "### 批处理调用模型推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1390396e-1af8-497f-85d9-6c9e984b4609",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'label': 'NEGATIVE', 'score': 0.9995032548904419},\n",
       " {'label': 'NEGATIVE', 'score': 0.9984821677207947},\n",
       " {'label': 'POSITIVE', 'score': 0.9961802959442139}]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text_list = [\n",
    "    \"Today Shanghai is really cold.\",\n",
    "    \"I think the taste of the garlic mashed pork in this store is average.\",\n",
    "    \"You learn things really quickly. You understand the theory class as soon as it is taught.\"\n",
    "]\n",
    "\n",
    "pipe(text_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3686758b-d7f7-4df9-889d-e63af47a138a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6d5a27fe-87d0-45fc-a31a-9a8db23e290a",
   "metadata": {},
   "source": [
    "## 使用 Pipeline API 调用更多预定义任务\n",
    "\n",
    "## Natural Language Processing(NLP)\n",
    "\n",
    "**NLP**(自然语言处理)任务是最常见的任务类型之一，因为文本是我们进行交流的一种自然方式。要将文本转换为模型可识别的格式，需要对其进行分词。这意味着将一系列文本划分为单独的单词或子词（标记），然后将这些标记转换为数字。结果就是，您可以将一系列文本表示为一系列数字，并且一旦您拥有了一系列数字，它就可以输入到模型中来解决各种NLP任务！\n",
    "\n",
    "上面演示的 文本分类任务，以及接下来的标记、问答等任务都属于 NLP 范畴。\n",
    "\n",
    "### Token Classification\n",
    "\n",
    "在任何NLP任务中，文本都经过预处理，将文本序列分成单个单词或子词。这些被称为tokens。\n",
    "\n",
    "**Token Classification**（Token分类）将每个token分配一个来自预定义类别集的标签。\n",
    "\n",
    "两种常见的 Token 分类是：\n",
    "\n",
    "- 命名实体识别（NER）：根据实体类别（如组织、人员、位置或日期）对token进行标记。NER在生物医学设置中特别受欢迎，可以标记基因、蛋白质和药物名称。\n",
    "- 词性标注（POS）：根据其词性（如名词、动词或形容词）对标记进行标记。POS对于帮助翻译系统了解两个相同的单词如何在语法上不同很有用（作为名词的银行与作为动词的银行）。\n",
    "\n",
    "模型主页：https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4f1ed125-f9ed-42d9-b102-dbc3172baefd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english).\n",
      "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
      "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(task=\"ner\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "68bab77c-0fe6-4781-b978-02d56829db33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'entity': 'I-ORG', 'score': 0.9968, 'index': 1, 'word': 'Hu', 'start': 0, 'end': 2}\n",
      "{'entity': 'I-ORG', 'score': 0.9293, 'index': 2, 'word': '##gging', 'start': 2, 'end': 7}\n",
      "{'entity': 'I-ORG', 'score': 0.9763, 'index': 3, 'word': 'Face', 'start': 8, 'end': 12}\n",
      "{'entity': 'I-MISC', 'score': 0.9983, 'index': 6, 'word': 'French', 'start': 18, 'end': 24}\n",
      "{'entity': 'I-LOC', 'score': 0.999, 'index': 10, 'word': 'New', 'start': 42, 'end': 45}\n",
      "{'entity': 'I-LOC', 'score': 0.9987, 'index': 11, 'word': 'York', 'start': 46, 'end': 50}\n",
      "{'entity': 'I-LOC', 'score': 0.9992, 'index': 12, 'word': 'City', 'start': 51, 'end': 55}\n"
     ]
    }
   ],
   "source": [
    "preds = classifier(\"Hugging Face is a French company based in New York City.\")\n",
    "preds = [\n",
    "    {\n",
    "        \"entity\": pred[\"entity\"],\n",
    "        \"score\": round(pred[\"score\"], 4),\n",
    "        \"index\": pred[\"index\"],\n",
    "        \"word\": pred[\"word\"],\n",
    "        \"start\": pred[\"start\"],\n",
    "        \"end\": pred[\"end\"],\n",
    "    }\n",
    "    for pred in preds\n",
    "]\n",
    "print(*preds, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21c72d55-e574-444f-96bc-62704022a148",
   "metadata": {},
   "source": [
    "#### 合并实体"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9b441191-6156-44b8-a323-db461ad06efb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english).\n",
      "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
      "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "/root/miniconda3/lib/python3.11/site-packages/transformers/pipelines/token_classification.py:169: UserWarning: `grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy=\"AggregationStrategy.SIMPLE\"` instead.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'entity_group': 'ORG',\n",
       "  'score': 0.9674639,\n",
       "  'word': 'Hugging Face',\n",
       "  'start': 0,\n",
       "  'end': 12},\n",
       " {'entity_group': 'MISC',\n",
       "  'score': 0.99828726,\n",
       "  'word': 'French',\n",
       "  'start': 18,\n",
       "  'end': 24},\n",
       " {'entity_group': 'LOC',\n",
       "  'score': 0.99896103,\n",
       "  'word': 'New York City',\n",
       "  'start': 42,\n",
       "  'end': 55}]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classifier = pipeline(task=\"ner\", grouped_entities=True)\n",
    "classifier(\"Hugging Face is a French company based in New York City.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7736791b-9585-4534-9efe-3fae3b7b6ca8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "369dda97-bd1a-4cb0-9636-47c2308c6289",
   "metadata": {},
   "source": [
    "### Question Answering\n",
    "\n",
    "**Question Answering**(问答)是另一个token-level的任务，返回一个问题的答案，有时带有上下文（开放领域），有时不带上下文（封闭领域）。每当我们向虚拟助手提出问题时，例如询问一家餐厅是否营业，就会发生这种情况。它还可以提供客户或技术支持，并帮助搜索引擎检索您要求的相关信息。\n",
    "\n",
    "有两种常见的问答类型：\n",
    "\n",
    "- 提取式：给定一个问题和一些上下文，模型必须从上下文中提取出一段文字作为答案\n",
    "- 生成式：给定一个问题和一些上下文，答案是根据上下文生成的；这种方法由`Text2TextGenerationPipeline`处理，而不是下面展示的`QuestionAnsweringPipeline`\n",
    "\n",
    "模型主页：https://huggingface.co/distilbert-base-cased-distilled-squad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a5281b0b-6b57-4884-92e1-cc2677987360",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to distilbert-base-cased-distilled-squad and revision 626af31 (https://huggingface.co/distilbert-base-cased-distilled-squad).\n",
      "Using a pipeline without specifying a model name and revision in production is not recommended.\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "question_answerer = pipeline(task=\"question-answering\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ca674a51-30a4-4dea-a443-e428925e990f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score: 0.9327, start: 30, end: 54, answer: huggingface/transformers\n"
     ]
    }
   ],
   "source": [
    "preds = question_answerer(\n",
    "    question=\"What is the name of the repository?\",\n",
    "    context=\"The name of the repository is huggingface/transformers\",\n",
    ")\n",
    "print(\n",
    "    f\"score: {round(preds['score'], 4)}, start: {preds['start']}, end: {preds['end']}, answer: {preds['answer']}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f2158feb-a528-410a-aabf-f3bd7f2726c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score: 0.9458, start: 115, end: 122, answer: Beijing\n"
     ]
    }
   ],
   "source": [
    "preds = question_answerer(\n",
    "    question=\"What is the capital of China?\",\n",
    "    context=\"On 1 October 1949, CCP Chairman Mao Zedong formally proclaimed the People's Republic of China in Tiananmen Square, Beijing.\",\n",
    ")\n",
    "print(\n",
    "    f\"score: {round(preds['score'], 4)}, start: {preds['start']}, end: {preds['end']}, answer: {preds['answer']}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb1d7e5-6baa-4142-81c1-5d311c0440e6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebd7483-99a5-4f2a-894c-13cf5a3b71b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d995cb85-ab8f-4b28-9413-1a83fa3e4c4d",
   "metadata": {},
   "source": [
    "### Summarization\n",
    "\n",
    "**Summarization**(文本摘要）从较长的文本中创建一个较短的版本，同时尽可能保留原始文档的大部分含义。摘要是一个序列到序列的任务；它输出比输入更短的文本序列。有许多长篇文档可以进行摘要，以帮助读者快速了解主要要点。法案、法律和财务文件、专利和科学论文等文档可以摘要，以节省读者的时间并作为阅读辅助工具。\n",
    "\n",
    "与问答类似，摘要有两种类型：\n",
    "\n",
    "- 提取式：从原始文本中识别和提取最重要的句子\n",
    "- 生成式：从原始文本中生成目标摘要（可能包括输入文件中没有的新单词）；`SummarizationPipeline`使用生成式方法\n",
    "\n",
    "模型主页：https://huggingface.co/t5-base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8e028f36-2b50-4ae3-8dce-2f0ac685e8fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.11/site-packages/transformers/models/t5/tokenization_t5_fast.py:160: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
      "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
      "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
      "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
      "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "summarizer = pipeline(task=\"summarization\",\n",
    "                      model=\"t5-base\",\n",
    "                      min_length=8,\n",
    "                      max_length=32,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8c760bbf-2ae4-4f84-bd5e-32c30ee6bd78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'summary_text': 'the Transformer is the first sequence transduction model based entirely on attention . it replaces recurrent layers commonly used in encoder-decode'}]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarizer(\n",
    "    \"\"\"\n",
    "    In this work, we presented the Transformer, the first sequence transduction model based entirely on attention, \n",
    "    replacing the recurrent layers most commonly used in encoder-decoder architectures with multi-headed self-attention. \n",
    "    For translation tasks, the Transformer can be trained significantly faster than architectures based on recurrent or convolutional layers. \n",
    "    On both WMT 2014 English-to-German and WMT 2014 English-to-French translation tasks, we achieve a new state of the art. \n",
    "    In the former task our best model outperforms even all previously reported ensembles.\n",
    "    \"\"\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b50ff52b-e61e-40cd-ab66-a591dc0c3b0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'summary_text': 'large language models (LLMs) are very large deep learning models pre-trained on vast amounts of data . transformers are capable of un'}]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summarizer(\n",
    "    '''\n",
    "    Large language models (LLM) are very large deep learning models that are pre-trained on vast amounts of data. \n",
    "    The underlying transformer is a set of neural networks that consist of an encoder and a decoder with self-attention capabilities. \n",
    "    The encoder and decoder extract meanings from a sequence of text and understand the relationships between words and phrases in it.\n",
    "    Transformer LLMs are capable of unsupervised training, although a more precise explanation is that transformers perform self-learning. \n",
    "    It is through this process that transformers learn to understand basic grammar, languages, and knowledge.\n",
    "    Unlike earlier recurrent neural networks (RNN) that sequentially process inputs, transformers process entire sequences in parallel. \n",
    "    This allows the data scientists to use GPUs for training transformer-based LLMs, significantly reducing the training time.\n",
    "    '''\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72315144-7fae-4848-af79-a70e428b2416",
   "metadata": {},
   "source": [
    "\n",
    "## Audio 音频处理任务\n",
    "\n",
    "音频和语音处理任务与其他模态略有不同，主要是因为音频作为输入是一个连续的信号。与文本不同，原始音频波形不能像句子可以被划分为单词那样被整齐地分割成离散的块。为了解决这个问题，通常在固定的时间间隔内对原始音频信号进行采样。如果在每个时间间隔内采样更多样本，采样率就会更高，音频更接近原始音频源。\n",
    "\n",
    "以前的方法是预处理音频以从中提取有用的特征。现在更常见的做法是直接将原始音频波形输入到特征编码器中，以提取音频表示。这样可以简化预处理步骤，并允许模型学习最重要的特征。\n",
    "\n",
    "### Audio classification\n",
    "\n",
    "**Audio classification**(音频分类)是一项将音频数据从预定义的类别集合中进行标记的任务。这是一个广泛的类别，具有许多具体的应用，其中一些包括：\n",
    "\n",
    "- 声学场景分类：使用场景标签（“办公室”、“海滩”、“体育场”）对音频进行标记。\n",
    "- 声学事件检测：使用声音事件标签（“汽车喇叭声”、“鲸鱼叫声”、“玻璃破碎声”）对音频进行标记。\n",
    "- 标记：对包含多种声音的音频进行标记（鸟鸣、会议中的说话人识别）。\n",
    "- 音乐分类：使用流派标签（“金属”、“嘻哈”、“乡村”）对音乐进行标记。\n",
    "\n",
    "模型主页：https://huggingface.co/superb/hubert-base-superb-er\n",
    "\n",
    "数据集主页：https://huggingface.co/datasets/superb#er\n",
    "\n",
    "```\n",
    "情感识别（ER）为每个话语预测一个情感类别。我们采用了最广泛使用的ER数据集IEMOCAP，并遵循传统的评估协议：我们删除不平衡的情感类别，只保留最后四个具有相似数量数据点的类别，并在标准分割的五折交叉验证上进行评估。评估指标是准确率（ACC）。\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4d4c1e9-8a67-4c72-8dce-ab326e0bc3b6",
   "metadata": {},
   "source": [
    "#### 前置依赖包安装\n",
    "\n",
    "建议在命令行安装必要的音频数据处理包: ffmpeg\n",
    "\n",
    "```shell\n",
    "$apt update & apt upgrade\n",
    "$apt install -y ffmpeg\n",
    "$pip install ffmpeg ffmpeg-python\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2e1f8a8d-75eb-49ab-9353-6c9d1f384cac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at superb/hubert-base-superb-er were not used when initializing HubertForSequenceClassification: ['hubert.encoder.pos_conv_embed.conv.weight_v', 'hubert.encoder.pos_conv_embed.conv.weight_g']\n",
      "- This IS expected if you are initializing HubertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing HubertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of HubertForSequenceClassification were not initialized from the model checkpoint at superb/hubert-base-superb-er and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(task=\"audio-classification\", model=\"superb/hubert-base-superb-er\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "204a4159-d14d-4623-8340-93d15abcc549",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.4532, 'label': 'hap'},\n",
       " {'score': 0.3622, 'label': 'sad'},\n",
       " {'score': 0.0943, 'label': 'neu'},\n",
       " {'score': 0.0903, 'label': 'ang'}]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用 Hugging Face Datasets 上的测试文件\n",
    "preds = classifier(\"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac\")\n",
    "preds = [{\"score\": round(pred[\"score\"], 4), \"label\": pred[\"label\"]} for pred in preds]\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2528bf81-8289-4bb9-bf2d-c0ce9f7e8b36",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.4532, 'label': 'hap'},\n",
       " {'score': 0.3622, 'label': 'sad'},\n",
       " {'score': 0.0943, 'label': 'neu'},\n",
       " {'score': 0.0903, 'label': 'ang'}]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用本地的音频文件做测试\n",
    "preds = classifier(\"data/audio/mlk.flac\")\n",
    "preds = [{\"score\": round(pred[\"score\"], 4), \"label\": pred[\"label\"]} for pred in preds]\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6826a83-f5a5-454e-ac16-1f215abef861",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e073cd1f-ba16-40a7-a029-bbd5e2a53dee",
   "metadata": {},
   "source": [
    "### Automatic speech recognition（ASR）\n",
    "\n",
    "**Automatic speech recognition**（自动语音识别）将语音转录为文本。这是最常见的音频任务之一，部分原因是因为语音是人类交流的自然形式。如今，ASR系统嵌入在智能技术产品中，如扬声器、电话和汽车。我们可以要求虚拟助手播放音乐、设置提醒和告诉我们天气。\n",
    "\n",
    "但是，Transformer架构帮助解决的一个关键挑战是低资源语言。通过在大量语音数据上进行预训练，仅在一个低资源语言的一小时标记语音数据上进行微调，仍然可以产生与以前在100倍更多标记数据上训练的ASR系统相比高质量的结果。\n",
    "\n",
    "模型主页：https://huggingface.co/openai/whisper-small\n",
    "\n",
    "下面展示使用 `OpenAI Whisper Small` 模型实现 ASR 的 Pipeline API 示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "68d14f8c-1571-4889-abd4-8fde09dc610a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# 使用 `model` 参数指定模型\n",
    "transcriber = pipeline(task=\"automatic-speech-recognition\", model=\"openai/whisper-small\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "316ebe18-9c8e-4ded-96b7-751304aa9122",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = transcriber(\"data/audio/mlk.flac\")\n",
    "text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "192d5e06-ac4a-4a56-be36-9e445751bda6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "228b1482-1e13-4557-8d37-5d56e961b5c5",
   "metadata": {},
   "source": [
    "## Computer Vision 计算机视觉\n",
    "\n",
    "**Computer Vision**（计算机视觉）任务中最早成功之一是使用卷积神经网络（CNN）识别邮政编码数字图像。图像由像素组成，每个像素都有一个数值。这使得将图像表示为像素值矩阵变得容易。每个像素值组合描述了图像的颜色。\n",
    "\n",
    "计算机视觉任务可以通过以下两种通用方式解决：\n",
    "\n",
    "- 使用卷积来学习图像的层次特征，从低级特征到高级抽象特征。\n",
    "- 将图像分成块，并使用Transformer逐步学习每个图像块如何相互关联以形成图像。与CNN偏好的自底向上方法不同，这种方法有点像从一个模糊的图像开始，然后逐渐将其聚焦清晰。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf0bf9f4-6413-409b-968e-f03ca0c88367",
   "metadata": {},
   "source": [
    "### Image Classificaiton\n",
    "\n",
    "**Image Classificaiton**(图像分类)将整个图像从预定义的类别集合中进行标记。像大多数分类任务一样，图像分类有许多实际用例，其中一些包括：\n",
    "\n",
    "- 医疗保健：标记医学图像以检测疾病或监测患者健康状况\n",
    "- 环境：标记卫星图像以监测森林砍伐、提供野外管理信息或检测野火\n",
    "- 农业：标记农作物图像以监测植物健康或用于土地使用监测的卫星图像\n",
    "- 生态学：标记动物或植物物种的图像以监测野生动物种群或跟踪濒危物种\n",
    "\n",
    "模型主页：https://huggingface.co/google/vit-base-patch16-224"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "98ca9736-10b9-45b0-a3e3-925f153bab57",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to google/vit-base-patch16-224 and revision 5dca96d (https://huggingface.co/google/vit-base-patch16-224).\n",
      "Using a pipeline without specifying a model name and revision in production is not recommended.\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(task=\"image-classification\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d8a41647-e8db-4d29-9dac-06c3c145acaa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'score': 0.4335, 'label': 'lynx, catamount'}\n",
      "{'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}\n",
      "{'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}\n",
      "{'score': 0.0239, 'label': 'Egyptian cat'}\n",
      "{'score': 0.0229, 'label': 'tiger cat'}\n"
     ]
    }
   ],
   "source": [
    "preds = classifier(\n",
    "    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\"\n",
    ")\n",
    "preds = [{\"score\": round(pred[\"score\"], 4), \"label\": pred[\"label\"]} for pred in preds]\n",
    "print(*preds, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "662ecb60-ebb9-4f4e-82bc-21e79bb22d90",
   "metadata": {},
   "source": [
    "![](data/image/cat-chonk.jpeg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "67592129-6430-4349-8016-7dde511ff69b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'score': 0.4335, 'label': 'lynx, catamount'}\n",
      "{'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}\n",
      "{'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}\n",
      "{'score': 0.0239, 'label': 'Egyptian cat'}\n",
      "{'score': 0.0229, 'label': 'tiger cat'}\n"
     ]
    }
   ],
   "source": [
    "# 使用本地图片（狼猫）\n",
    "preds = classifier(\n",
    "    \"data/image/cat-chonk.jpeg\"\n",
    ")\n",
    "preds = [{\"score\": round(pred[\"score\"], 4), \"label\": pred[\"label\"]} for pred in preds]\n",
    "print(*preds, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8113c251-a648-4ea2-baa4-3081bb490c70",
   "metadata": {},
   "source": [
    "![](data/image/panda.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a7e9b51a-d759-4398-9894-f10933dbed47",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'score': 0.9962, 'label': 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca'}\n",
      "{'score': 0.0018, 'label': 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens'}\n",
      "{'score': 0.0002, 'label': 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus'}\n",
      "{'score': 0.0001, 'label': 'sloth bear, Melursus ursinus, Ursus ursinus'}\n",
      "{'score': 0.0001, 'label': 'brown bear, bruin, Ursus arctos'}\n"
     ]
    }
   ],
   "source": [
    "# 使用本地图片（熊猫）\n",
    "preds = classifier(\n",
    "    \"data/image/panda.jpg\"\n",
    ")\n",
    "preds = [{\"score\": round(pred[\"score\"], 4), \"label\": pred[\"label\"]} for pred in preds]\n",
    "print(*preds, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "120f5fdb-2803-4f4f-9008-00dee9a6a557",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8118a561-3dc9-4693-8c3b-3e29c37b4d71",
   "metadata": {},
   "source": [
    "### Object Detection\n",
    "\n",
    "与图像分类不同，目标检测在图像中识别多个对象以及这些对象在图像中的位置（由边界框定义）。目标检测的一些示例应用包括：\n",
    "\n",
    "- 自动驾驶车辆：检测日常交通对象，如其他车辆、行人和红绿灯\n",
    "- 遥感：灾害监测、城市规划和天气预报\n",
    "- 缺陷检测：检测建筑物中的裂缝或结构损坏，以及制造业产品缺陷\n",
    "\n",
    "模型主页：https://huggingface.co/facebook/detr-resnet-50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7122f04-7add-4623-8f3f-ccc00247524b",
   "metadata": {},
   "source": [
    "#### 前置依赖包安装"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8cbc4227-cbcc-4f18-bf26-66535e66afb6",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: timm in /root/miniconda3/lib/python3.11/site-packages (0.9.12)\n",
      "Requirement already satisfied: torch>=1.7 in /root/miniconda3/lib/python3.11/site-packages (from timm) (2.3.0.dev20240116+cu121)\n",
      "Requirement already satisfied: torchvision in /root/miniconda3/lib/python3.11/site-packages (from timm) (0.18.0.dev20240117+cu121)\n",
      "Requirement already satisfied: pyyaml in /root/miniconda3/lib/python3.11/site-packages (from timm) (6.0.1)\n",
      "Requirement already satisfied: huggingface-hub in /root/miniconda3/lib/python3.11/site-packages (from timm) (0.20.1)\n",
      "Requirement already satisfied: safetensors in /root/miniconda3/lib/python3.11/site-packages (from timm) (0.4.0)\n",
      "Requirement already satisfied: filelock in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (3.9.0)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (4.8.0)\n",
      "Requirement already satisfied: sympy in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (1.11.1)\n",
      "Requirement already satisfied: networkx in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (3.0rc1)\n",
      "Requirement already satisfied: jinja2 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (2023.4.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (12.1.105)\n",
      "Requirement already satisfied: pytorch-triton==2.2.0+e28a256d71 in /root/miniconda3/lib/python3.11/site-packages (from torch>=1.7->timm) (2.2.0+e28a256d71)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /root/miniconda3/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.7->timm) (12.1.105)\n",
      "Collecting fsspec (from torch>=1.7->timm)\n",
      "  Using cached fsspec-2023.12.2-py3-none-any.whl.metadata (6.8 kB)\n",
      "Requirement already satisfied: requests in /root/miniconda3/lib/python3.11/site-packages (from huggingface-hub->timm) (2.28.1)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /root/miniconda3/lib/python3.11/site-packages (from huggingface-hub->timm) (4.66.1)\n",
      "Requirement already satisfied: packaging>=20.9 in /root/miniconda3/lib/python3.11/site-packages (from huggingface-hub->timm) (23.2)\n",
      "Requirement already satisfied: numpy in /root/miniconda3/lib/python3.11/site-packages (from torchvision->timm) (1.24.1)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /root/miniconda3/lib/python3.11/site-packages (from torchvision->timm) (9.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /root/miniconda3/lib/python3.11/site-packages (from jinja2->torch>=1.7->timm) (2.1.3)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /root/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub->timm) (2.1.1)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub->timm) (3.4)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /root/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub->timm) (1.26.13)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub->timm) (2022.12.7)\n",
      "Requirement already satisfied: mpmath>=0.19 in /root/miniconda3/lib/python3.11/site-packages (from sympy->torch>=1.7->timm) (1.2.1)\n",
      "Using cached fsspec-2023.12.2-py3-none-any.whl (168 kB)\n",
      "Installing collected packages: fsspec\n",
      "  Attempting uninstall: fsspec\n",
      "    Found existing installation: fsspec 2023.4.0\n",
      "    Uninstalling fsspec-2023.4.0:\n",
      "      Successfully uninstalled fsspec-2023.4.0\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "datasets 2.16.1 requires fsspec[http]<=2023.10.0,>=2023.1.0, but you have fsspec 2023.12.2 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed fsspec-2023.12.2\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c3401126-756b-4394-930c-c833c1f574b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to facebook/detr-resnet-50 and revision 2729413 (https://huggingface.co/facebook/detr-resnet-50).\n",
      "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
      "/root/miniconda3/lib/python3.11/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libpng16.so.16: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
      "  warn(\n",
      "Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked']\n",
      "- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n",
      "The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "detector = pipeline(task=\"object-detection\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "32dc65d4-3868-4d10-a433-9f50832e8e68",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.9864,\n",
       "  'label': 'cat',\n",
       "  'box': {'xmin': 178, 'ymin': 154, 'xmax': 882, 'ymax': 598}}]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds = detector(\n",
    "    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\"\n",
    ")\n",
    "preds = [{\"score\": round(pred[\"score\"], 4), \"label\": pred[\"label\"], \"box\": pred[\"box\"]} for pred in preds]\n",
    "preds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "550759be-93ab-4988-81f0-c0cb8eccfda8",
   "metadata": {},
   "source": [
    "![](data/image/cat_dog.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "40f66cd7-f2c8-4af3-b9ab-60012792ec1b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.9985,\n",
       "  'label': 'cat',\n",
       "  'box': {'xmin': 78, 'ymin': 57, 'xmax': 309, 'ymax': 371}},\n",
       " {'score': 0.989,\n",
       "  'label': 'dog',\n",
       "  'box': {'xmin': 279, 'ymin': 20, 'xmax': 482, 'ymax': 416}}]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds = detector(\n",
    "    \"data/image/cat_dog.jpg\"\n",
    ")\n",
    "preds = [{\"score\": round(pred[\"score\"], 4), \"label\": pred[\"label\"], \"box\": pred[\"box\"]} for pred in preds]\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d591246-662b-47e0-8531-f862b05ba1ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1565e1d-20bf-4b19-86c1-dd3a8983a999",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e4993a9f-1032-46ab-97b1-a028d4bd5ebc",
   "metadata": {},
   "source": [
    "### Homework：替换以上示例中的模型，对比不同模型在相同任务上的性能表现\n",
    "\n",
    "在 Hugging Face Models 中找到适合你的模型：https://huggingface.co/models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be3f51d2-863d-44f6-836f-9051351985ae",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
